{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from cot import Collection\n",
    "from cot.stats import evaluation_as_table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[nltk_data] Downloading package punkt to /home/kon/nltk_data...\n",
      "[nltk_data]   Package punkt is already up-to-date!\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "31d84d04cdbb4a65a8e360aa0a230cf5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5882967227dc4bf8bb3c05180a568281",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c851ac77c65a44daa6d63cabb4f6ad56",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0f206862593d4a6aba179e9bcc7f53a5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "00553b6795044eb8a5a8687e7d2e17e9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f1b4ec75693f465f8ab0361002af92bc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "coll = Collection.load_thoughtsource_100(\"all\",load_pregenerated_cots=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "| Name           | Train   | Valid   | Test   |\n",
       "|----------------|---------|---------|--------|\n",
       "| commonsense_qa | -       | 100     | -      |\n",
       "| med_qa         | -       | -       | 100    |\n",
       "| medmc_qa       | -       | 100     | -      |\n",
       "| open_book_qa   | -       | -       | 100    |\n",
       "| strategy_qa    | 100     | -       | -      |\n",
       "| worldtree      | -       | -       | 100    |\n",
       "\n",
       "Not loaded: ['aqua', 'asdiv', 'entailment_bank', 'gsm8k', 'mawps', 'pubmed_qa', 'qed', 'svamp']"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "coll"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration of the input and parameters of the language model \n",
    "config={\n",
    "    \"instruction_keys\": None,\n",
    "    \"cot_trigger_keys\": \"zhou-01\",\n",
    "    \"answer_extraction_keys\": 'auto-kojima', \n",
    "    \"author\" : \"thoughtsource\",\n",
    "    \"api_service\": \"openai_chat\",\n",
    "    \"engine\": \"gpt-3.5-turbo\", \n",
    "    \"temperature\": 0,\n",
    "    \"max_tokens\": 512,\n",
    "    \"verbose\": False,\n",
    "    \"warn\": False,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating commonsense_qa...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1d287d485d1d4f998306351006bd9f76",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Retrying langchain.chat_models.openai.ChatOpenAI.completion_with_retry.<locals>._completion_with_retry in 4.0 seconds as it raised RateLimitError: The server had an error while processing your request. Sorry about that!.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating med_qa...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "69e6ed0ff5be469a879e4ac80e3adf56",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating medmc_qa...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "349f61982ff7400aaa2fb0b6a0555fe3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating open_book_qa...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5b84e6c6c20c470ea6f90a9940266b7c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating strategy_qa...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0c2a7f04979c4648a6a110e012c08c00",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating worldtree...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "311f9ee8a636458c858afe8ac1025196",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "coll.generate(config=config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eval = coll.evaluate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 101,
   "metadata": {},
   "outputs": [],
   "source": [
    "def join_strings(list_of_strings):\n",
    "    \"\"\"Join values into a single hyphen-separated string (used to build dump file names).\n",
    "\n",
    "    Args:\n",
    "        list_of_strings: None, a single str, or an iterable of values.\n",
    "\n",
    "    Returns:\n",
    "        str: \"None\" for None, the string itself for a single str, otherwise\n",
    "        str() of each item joined with \"-\" (an empty iterable gives \"\").\n",
    "    \"\"\"\n",
    "    if list_of_strings is None:\n",
    "        # Render a missing key as the literal text \"None\" so the file name stays well-formed\n",
    "        list_of_strings = [\"None\"]\n",
    "    elif isinstance(list_of_strings, str):\n",
    "        # Wrap a bare string so it is not joined character by character\n",
    "        list_of_strings = [list_of_strings]\n",
    "    return \"-\".join(str(item) for item in list_of_strings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4f216abb45fe4776b3b62cc6a53f16a2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "36f7b64de7274eb29ae787631a264cc9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7a4c3e2409ec46f6a1fdcb18ebe1be06",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6a88c179cccd4e378ac4e2495d7fecb0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3945314066b744048550b96d99ac47c1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bdb6e23137034a968b99b86986e6beaf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "coll.dump(\"thoughtsource_100\" + \"_\" + config['api_service'] + \"_\" + config['engine'].replace(\"/\", \"_\") + \"_\" + join_strings(config[\"instruction_keys\"]) + \"_\" + join_strings(config[\"cot_trigger_keys\"]) + \".json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr th {\n",
       "        text-align: left;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th>zhou-01</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th>gpt-3.5-turbo</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>commonsense_qa</th>\n",
       "      <td>0.66</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>med_qa</th>\n",
       "      <td>0.65</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>medmc_qa</th>\n",
       "      <td>0.48</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>open_book_qa</th>\n",
       "      <td>0.81</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>strategy_qa</th>\n",
       "      <td>0.59</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>worldtree</th>\n",
       "      <td>0.92</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                     zhou-01\n",
       "               gpt-3.5-turbo\n",
       "commonsense_qa          0.66\n",
       "med_qa                  0.65\n",
       "medmc_qa                0.48\n",
       "open_book_qa            0.81\n",
       "strategy_qa             0.59\n",
       "worldtree               0.92"
      ]
     },
     "execution_count": 86,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "evaluation_as_table(eval)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr th {\n",
       "        text-align: left;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th colspan=\"5\" halign=\"left\">None</th>\n",
       "      <th colspan=\"5\" halign=\"left\">kojima-01</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th>command-xlarge-nightly</th>\n",
       "      <th>flan-T5-xxl</th>\n",
       "      <th>gpt-3.5-turbo</th>\n",
       "      <th>text-davinci-002</th>\n",
       "      <th>text-davinci-003</th>\n",
       "      <th>command-xlarge-nightly</th>\n",
       "      <th>flan-T5-xxl</th>\n",
       "      <th>gpt-3.5-turbo</th>\n",
       "      <th>text-davinci-002</th>\n",
       "      <th>text-davinci-003</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>commonsense_qa</th>\n",
       "      <td>0.51</td>\n",
       "      <td>0.87</td>\n",
       "      <td>0.72</td>\n",
       "      <td>0.69</td>\n",
       "      <td>0.72</td>\n",
       "      <td>0.53</td>\n",
       "      <td>0.81</td>\n",
       "      <td>0.67</td>\n",
       "      <td>0.7</td>\n",
       "      <td>0.65</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>med_qa</th>\n",
       "      <td>0.23</td>\n",
       "      <td>0.32</td>\n",
       "      <td>0.58</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.43</td>\n",
       "      <td>0.31</td>\n",
       "      <td>0.34</td>\n",
       "      <td>0.59</td>\n",
       "      <td>0.46</td>\n",
       "      <td>0.43</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>medmc_qa</th>\n",
       "      <td>0.25</td>\n",
       "      <td>0.34</td>\n",
       "      <td>0.58</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.4</td>\n",
       "      <td>0.22</td>\n",
       "      <td>0.35</td>\n",
       "      <td>0.47</td>\n",
       "      <td>0.37</td>\n",
       "      <td>0.36</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>open_book_qa</th>\n",
       "      <td>0.59</td>\n",
       "      <td>0.82</td>\n",
       "      <td>0.77</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.7</td>\n",
       "      <td>0.38</td>\n",
       "      <td>0.79</td>\n",
       "      <td>0.77</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.67</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>strategy_qa</th>\n",
       "      <td>0.52</td>\n",
       "      <td>0.61</td>\n",
       "      <td>0.57</td>\n",
       "      <td>0.59</td>\n",
       "      <td>0.53</td>\n",
       "      <td>0.59</td>\n",
       "      <td>0.69</td>\n",
       "      <td>0.56</td>\n",
       "      <td>0.54</td>\n",
       "      <td>0.54</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>worldtree</th>\n",
       "      <td>0.63</td>\n",
       "      <td>0.85</td>\n",
       "      <td>0.96</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.91</td>\n",
       "      <td>0.58</td>\n",
       "      <td>0.8</td>\n",
       "      <td>0.93</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.89</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                 None                            \\\n",
       "               command-xlarge-nightly flan-T5-xxl gpt-3.5-turbo   \n",
       "commonsense_qa                   0.51        0.87          0.72   \n",
       "med_qa                           0.23        0.32          0.58   \n",
       "medmc_qa                         0.25        0.34          0.58   \n",
       "open_book_qa                     0.59        0.82          0.77   \n",
       "strategy_qa                      0.52        0.61          0.57   \n",
       "worldtree                        0.63        0.85          0.96   \n",
       "\n",
       "                                                              kojima-01  \\\n",
       "               text-davinci-002 text-davinci-003 command-xlarge-nightly   \n",
       "commonsense_qa             0.69             0.72                   0.53   \n",
       "med_qa                      NaN             0.43                   0.31   \n",
       "medmc_qa                    NaN              0.4                   0.22   \n",
       "open_book_qa                NaN              0.7                   0.38   \n",
       "strategy_qa                0.59             0.53                   0.59   \n",
       "worldtree                   NaN             0.91                   0.58   \n",
       "\n",
       "                                                                            \n",
       "               flan-T5-xxl gpt-3.5-turbo text-davinci-002 text-davinci-003  \n",
       "commonsense_qa        0.81          0.67              0.7             0.65  \n",
       "med_qa                0.34          0.59             0.46             0.43  \n",
       "medmc_qa              0.35          0.47             0.37             0.36  \n",
       "open_book_qa          0.79          0.77              NaN             0.67  \n",
       "strategy_qa           0.69          0.56             0.54             0.54  \n",
       "worldtree              0.8          0.93              NaN             0.89  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "\n",
    "evaluation_as_table(eval)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# coll.dump(\"thoughtsource_100_gpt-3.5-turbo_None_kojima-01-and-None_auto-kojima.json\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### merge into ts_100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relative path: resolved against the kernel's working directory, which is also where the\n",
    "# `coll.dump(...)` calls in this notebook write their files (no user-specific absolute path).\n",
    "coll = Collection.from_json(\"thoughtsource_100_openai_text-davinci-002_None_kojima-01.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "ts_100 = Collection.load_thoughtsource_100()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ts_100_merged = ts_100.merge(coll)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ts_100_merged.dump(\"thoughtsource_100.json\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {},
   "outputs": [],
   "source": [
    "coll = Collection.load_thoughtsource_100(\"all\",load_pregenerated_cots=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [],
   "source": [
    "# coll.select_generated_cots(model='gpt-3.5-turbo')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [],
   "source": [
    "coll.select_generated_cots(author='kojima')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ca3a098652f348179bcc2f4a931a0e3b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b60113479b44496db6533b907df92b0e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b32479667c534dbe89581d351f7ae497",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f41c91e9748d4bf0a930ceeb5d525753",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "02a76c12e55a45c8a9b0b98f42d2faba",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6dd2163944ea429fa35f1783eaf83389",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?ex/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "eval = coll.evaluate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr th {\n",
       "        text-align: left;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th colspan=\"5\" halign=\"left\">None</th>\n",
       "      <th colspan=\"5\" halign=\"left\">kojima-01</th>\n",
       "      <th>zhou-01</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th>command-xlarge-nightly</th>\n",
       "      <th>flan-T5-xxl</th>\n",
       "      <th>gpt-3.5-turbo</th>\n",
       "      <th>text-davinci-002</th>\n",
       "      <th>text-davinci-003</th>\n",
       "      <th>command-xlarge-nightly</th>\n",
       "      <th>flan-T5-xxl</th>\n",
       "      <th>gpt-3.5-turbo</th>\n",
       "      <th>text-davinci-002</th>\n",
       "      <th>text-davinci-003</th>\n",
       "      <th>gpt-3.5-turbo</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>commonsense_qa</th>\n",
       "      <td>0.51</td>\n",
       "      <td>0.87</td>\n",
       "      <td>0.72</td>\n",
       "      <td>0.76</td>\n",
       "      <td>0.72</td>\n",
       "      <td>0.53</td>\n",
       "      <td>0.81</td>\n",
       "      <td>0.67</td>\n",
       "      <td>0.62</td>\n",
       "      <td>0.65</td>\n",
       "      <td>0.66</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>med_qa</th>\n",
       "      <td>0.23</td>\n",
       "      <td>0.32</td>\n",
       "      <td>0.58</td>\n",
       "      <td>0.41</td>\n",
       "      <td>0.43</td>\n",
       "      <td>0.31</td>\n",
       "      <td>0.34</td>\n",
       "      <td>0.59</td>\n",
       "      <td>0.34</td>\n",
       "      <td>0.43</td>\n",
       "      <td>0.65</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>medmc_qa</th>\n",
       "      <td>0.25</td>\n",
       "      <td>0.34</td>\n",
       "      <td>0.58</td>\n",
       "      <td>0.34</td>\n",
       "      <td>0.4</td>\n",
       "      <td>0.22</td>\n",
       "      <td>0.35</td>\n",
       "      <td>0.47</td>\n",
       "      <td>0.34</td>\n",
       "      <td>0.36</td>\n",
       "      <td>0.48</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>open_book_qa</th>\n",
       "      <td>0.59</td>\n",
       "      <td>0.82</td>\n",
       "      <td>0.77</td>\n",
       "      <td>0.67</td>\n",
       "      <td>0.7</td>\n",
       "      <td>0.38</td>\n",
       "      <td>0.79</td>\n",
       "      <td>0.77</td>\n",
       "      <td>0.57</td>\n",
       "      <td>0.67</td>\n",
       "      <td>0.81</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>strategy_qa</th>\n",
       "      <td>0.52</td>\n",
       "      <td>0.61</td>\n",
       "      <td>0.57</td>\n",
       "      <td>0.36</td>\n",
       "      <td>0.53</td>\n",
       "      <td>0.59</td>\n",
       "      <td>0.69</td>\n",
       "      <td>0.56</td>\n",
       "      <td>0.38</td>\n",
       "      <td>0.54</td>\n",
       "      <td>0.59</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>worldtree</th>\n",
       "      <td>0.63</td>\n",
       "      <td>0.85</td>\n",
       "      <td>0.96</td>\n",
       "      <td>0.88</td>\n",
       "      <td>0.91</td>\n",
       "      <td>0.58</td>\n",
       "      <td>0.8</td>\n",
       "      <td>0.93</td>\n",
       "      <td>0.78</td>\n",
       "      <td>0.89</td>\n",
       "      <td>0.92</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Average</th>\n",
       "      <td>0.455</td>\n",
       "      <td>0.635</td>\n",
       "      <td>0.6967</td>\n",
       "      <td>0.57</td>\n",
       "      <td>0.615</td>\n",
       "      <td>0.435</td>\n",
       "      <td>0.63</td>\n",
       "      <td>0.665</td>\n",
       "      <td>0.505</td>\n",
       "      <td>0.59</td>\n",
       "      <td>0.685</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                 None                            \\\n",
       "               command-xlarge-nightly flan-T5-xxl gpt-3.5-turbo   \n",
       "commonsense_qa                   0.51        0.87          0.72   \n",
       "med_qa                           0.23        0.32          0.58   \n",
       "medmc_qa                         0.25        0.34          0.58   \n",
       "open_book_qa                     0.59        0.82          0.77   \n",
       "strategy_qa                      0.52        0.61          0.57   \n",
       "worldtree                        0.63        0.85          0.96   \n",
       "Average                         0.455       0.635        0.6967   \n",
       "\n",
       "                                                              kojima-01  \\\n",
       "               text-davinci-002 text-davinci-003 command-xlarge-nightly   \n",
       "commonsense_qa             0.76             0.72                   0.53   \n",
       "med_qa                     0.41             0.43                   0.31   \n",
       "medmc_qa                   0.34              0.4                   0.22   \n",
       "open_book_qa               0.67              0.7                   0.38   \n",
       "strategy_qa                0.36             0.53                   0.59   \n",
       "worldtree                  0.88             0.91                   0.58   \n",
       "Average                    0.57            0.615                  0.435   \n",
       "\n",
       "                                                                            \\\n",
       "               flan-T5-xxl gpt-3.5-turbo text-davinci-002 text-davinci-003   \n",
       "commonsense_qa        0.81          0.67             0.62             0.65   \n",
       "med_qa                0.34          0.59             0.34             0.43   \n",
       "medmc_qa              0.35          0.47             0.34             0.36   \n",
       "open_book_qa          0.79          0.77             0.57             0.67   \n",
       "strategy_qa           0.69          0.56             0.38             0.54   \n",
       "worldtree              0.8          0.93             0.78             0.89   \n",
       "Average               0.63         0.665            0.505             0.59   \n",
       "\n",
       "                     zhou-01  \n",
       "               gpt-3.5-turbo  \n",
       "commonsense_qa          0.66  \n",
       "med_qa                  0.65  \n",
       "medmc_qa                0.48  \n",
       "open_book_qa            0.81  \n",
       "strategy_qa             0.59  \n",
       "worldtree               0.92  \n",
       "Average                0.685  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# author thoughtsource\n",
    "evaluation_as_table(eval)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr th {\n",
       "        text-align: left;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th>kojima-01</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th>text-davinci-002</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>commonsense_qa</th>\n",
       "      <td>0.7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>med_qa</th>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>medmc_qa</th>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>open_book_qa</th>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>strategy_qa</th>\n",
       "      <td>0.5408</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>worldtree</th>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Average</th>\n",
       "      <td>0.6204</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                      kojima-01\n",
       "               text-davinci-002\n",
       "commonsense_qa              0.7\n",
       "med_qa                      NaN\n",
       "medmc_qa                    NaN\n",
       "open_book_qa                NaN\n",
       "strategy_qa              0.5408\n",
       "worldtree                   NaN\n",
       "Average                  0.6204"
      ]
     },
     "execution_count": 82,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# author kojima\n",
    "evaluation_as_table(eval)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead tr th {\n",
       "        text-align: left;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th>None</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th>text-davinci-002</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>commonsense_qa</th>\n",
       "      <td>0.6869</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>med_qa</th>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>medmc_qa</th>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>open_book_qa</th>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>strategy_qa</th>\n",
       "      <td>0.59</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>worldtree</th>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Average</th>\n",
       "      <td>0.6384</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                           None\n",
       "               text-davinci-002\n",
       "commonsense_qa           0.6869\n",
       "med_qa                      NaN\n",
       "medmc_qa                    NaN\n",
       "open_book_qa                NaN\n",
       "strategy_qa                0.59\n",
       "worldtree                   NaN\n",
       "Average                  0.6384"
      ]
     },
     "execution_count": 73,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# author wei\n",
    "# evaluation_as_table(eval)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### thoughtsource_1 dataset selection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# coll = coll.select(split=\"all\", number_samples=1, random_samples=True, seed=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d8fea24c6d2c40169c272d89bbbd1453",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "60956474953d4170bed511215b2bb385",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6493c8ede5c742758ab9edf76a2c7751",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "372e27ffdbe6466196ea7a9ef3a4ea39",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "22c78559fb11492e88ec1dffe0268ba3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ee232555baab4eb9b23b11042e438151",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# coll.dump(\"thoughtsource_1_selected\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
